// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include "single_layer_tests/experimental_detectron_generate_proposals_single_image.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"

using namespace ov::test;
using namespace ov::test::subgraph;

namespace {

// Operation parameter values swept by both instantiations below;
// ::testing::Combine produces one test case per element of the Cartesian
// product of these vectors, the shapes and the input sets.
const std::vector<float> min_size = { 0.0f, 0.1f };
const std::vector<float> nms_threshold = { 0.7f };
const std::vector<int64_t> post_nms_count = { 6 };
const std::vector<int64_t> pre_nms_count = { 14, 1000 };

/**
 * @brief Builds the named input sets for the layer test.
 *
 * Each set holds four tensors in input order: im_info {3}, anchors {36, 4},
 * deltas {12, 2, 6} and scores {3, 2, 6}, matching `input_shape` below.
 *
 * @tparam T Host element type of every tensor (ov::float16 or float in this
 *           file); the tensor element type is ov::element::from<T>().
 * @return   (case name, tensors) pairs: "empty" and "filled". Returned by
 *           non-const value so callers are free to move from the result.
 */
template <typename T>
std::vector<std::pair<std::string, std::vector<ov::Tensor>>> getInputTensors() {
    // Wraps an explicit, hand-written list of values into a tensor.
    // NOTE(review): values.size() must equal ov::shape_size(shape); the lists
    // below are maintained by hand, so keep each count in sync with its shape.
    const auto make_tensor = [](const ov::Shape& shape, const std::vector<T>& values) {
        return ov::test::utils::create_tensor<T>(ov::element::from<T>(), shape, values);
    };
    // Creates a tensor whose every element equals `value`. The element count is
    // derived from the shape, so it cannot drift from it.
    const auto make_filled_tensor = [](const ov::Shape& shape, const T& value) {
        return ov::test::utils::create_tensor<T>(ov::element::from<T>(), shape,
                                                 std::vector<T>(ov::shape_size(shape), value));
    };

    // "empty": every value is 1.0 except three scores (5.0, 4.0 and 8.0).
    const std::vector<ov::Tensor> empty_inputs = {
        make_filled_tensor(ov::Shape{3}, T(1.0f)),         // im_info: 3 values
        make_filled_tensor(ov::Shape{36, 4}, T(1.0f)),     // anchors: 36 x 4 = 144 values
        make_filled_tensor(ov::Shape{12, 2, 6}, T(1.0f)),  // deltas: 12 x 2 x 6 = 144 values
        // scores: 3 x 2 x 6 = 36 values
        make_tensor(ov::Shape{3, 2, 6}, std::vector<T>{
            5.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
            1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f, 1.0f,
            1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f}),
    };

    // "filled": explicit im_info, anchors, deltas and scores.
    const std::vector<ov::Tensor> filled_inputs = {
        // im_info: 3 values
        make_tensor(ov::Shape{3}, std::vector<T>{150.0, 150.0, 1.0}),
        // anchors: 36 x 4 = 144 values (two 4-value rows per line)
        make_tensor(ov::Shape{36, 4}, std::vector<T>{
            12.0, 68.0, 102.0, 123.0, 46.0, 80.0, 79.0, 128.0,
            33.0, 71.0, 127.0, 86.0, 33.0, 56.0, 150.0, 73.0,
            5.0, 41.0, 93.0, 150.0, 74.0, 66.0, 106.0, 115.0,
            17.0, 37.0, 87.0, 150.0, 31.0, 27.0, 150.0, 39.0,
            29.0, 23.0, 112.0, 123.0, 41.0, 37.0, 103.0, 150.0,
            8.0, 46.0, 98.0, 111.0, 7.0, 69.0, 114.0, 150.0,
            70.0, 21.0, 150.0, 125.0, 54.0, 19.0, 132.0, 68.0,
            62.0, 8.0, 150.0, 101.0, 57.0, 81.0, 150.0, 97.0,
            79.0, 29.0, 109.0, 130.0, 12.0, 63.0, 100.0, 150.0,
            17.0, 33.0, 113.0, 150.0, 90.0, 78.0, 150.0, 111.0,
            47.0, 68.0, 150.0, 71.0, 66.0, 103.0, 111.0, 150.0,
            4.0, 17.0, 112.0, 94.0, 12.0, 8.0, 119.0, 98.0,
            54.0, 56.0, 120.0, 150.0, 56.0, 29.0, 150.0, 31.0,
            42.0, 3.0, 139.0, 92.0, 41.0, 65.0, 150.0, 130.0,
            49.0, 13.0, 143.0, 30.0, 40.0, 60.0, 150.0, 150.0,
            23.0, 73.0, 24.0, 115.0, 56.0, 84.0, 107.0, 108.0,
            63.0, 8.0, 142.0, 125.0, 78.0, 37.0, 93.0, 144.0,
            40.0, 34.0, 150.0, 46.0, 30.0, 21.0, 150.0, 120.0}),
        // deltas: 12 x 2 x 6 = 144 values (9 per line)
        make_tensor(ov::Shape{12, 2, 6}, std::vector<T>{
            9.062256, 10.883133, 9.8441105, 12.694285, 0.41781136, 8.749107, 14.990341, 6.587644, 1.4206103,
            13.299262, 12.432549, 2.736371, 0.22732796, 6.3361835, 12.268727, 2.1009045, 4.771589, 2.5131326,
            5.610736, 9.3604145, 4.27379, 8.317948, 0.60510135, 6.7446275, 1.0207708, 1.1352817, 1.5785321,
            1.718335, 1.8093798, 0.99247587, 1.3233583, 1.7432803, 1.8534478, 1.2593061, 1.7394226, 1.7686696,
            1.647999, 1.7611449, 1.3119122, 0.03007332, 1.1106564, 0.55669737, 0.2546148, 1.9181818, 0.7134989,
            2.0407224, 1.7211134, 1.8565536, 14.562747, 2.8786168, 0.5927796, 0.2064463, 7.6794515, 8.672126,
            10.139171, 8.002429, 7.002932, 12.6314945, 10.550842, 0.15784842, 0.3194304, 10.752157, 3.709805,
            11.628928, 0.7136225, 14.619964, 15.177284, 2.2824087, 15.381494, 0.16618137, 7.507227, 11.173228,
            0.4923559, 1.8227729, 1.4749299, 1.7833921, 1.2363617, -0.23659119, 1.5737582, 1.779316, 1.9828427,
            1.0482665, 1.4900246, 1.3563544, 1.5341306, 0.7634312, 4.6216766e-05, 1.6161222, 1.7512476, 1.9363779,
            0.9195784, 1.4906164, -0.03244795, 0.681073, 0.6192401, 1.8033613, 14.146055, 3.4043705, 15.292292,
            3.5295358, 11.138999, 9.952057, 5.633434, 12.114562, 9.427372, 12.384038, 9.583308, 8.427233,
            15.293704, 3.288159, 11.64898, 9.350885, 2.0037227, 13.523184, 4.4176426, 6.1057625, 14.400079,
            8.248259, 11.815807, 15.713364, 1.0023532, 1.3203261, 1.7100681, 0.7407832, 1.09448, 1.7188418,
            1.4412547, 1.4862992, 0.74790007, 0.31571656, 0.6398838, 2.0236106, 1.1869069, 1.7265586, 1.2624544,
            0.09934269, 1.3508598, 0.85212964, -0.38968498, 1.7059708, 1.6533034, 1.7400402, 1.8123854, -0.43063712}),
        // scores: 3 x 2 x 6 = 36 values (9 per line)
        make_tensor(ov::Shape{3, 2, 6}, std::vector<T>{
            0.7719922, 0.35906568, 0.29054508, 0.18124384, 0.5604661, 0.84750974, 0.98948747, 0.009793862, 0.7184191,
            0.5560748, 0.6952493, 0.6732593, 0.3306898, 0.6790913, 0.41128764, 0.34593266, 0.94296855, 0.7348507,
            0.24478768, 0.94024557, 0.05405676, 0.06466125, 0.36244348, 0.07942984, 0.10619422, 0.09412837, 0.9053611,
            0.22870538, 0.9237487, 0.20986171, 0.5067282, 0.29709867, 0.53138554, 0.189101, 0.4786443, 0.88421875}),
    };

    return {
        {"empty", empty_inputs},
        {"filled", filled_inputs},
    };
}

// A single static-shape configuration; inputs are listed in the order
// im_info / anchors / deltas / scores.
const std::vector<std::vector<InputShape>> input_shape = {
    static_shapes_to_test_representation({
        {3},         // im_info
        {36, 4},     // anchors
        {12, 2, 6},  // deltas
        {3, 2, 6},   // scores
    }),
};

// GPU, f16: every combination of the parameter vectors above with both input
// sets produced by getInputTensors<ov::float16>().
INSTANTIATE_TEST_SUITE_P(
    smoke_ExperimentalDetectronGenerateProposalsSingleImageLayerTest_f16,
    ExperimentalDetectronGenerateProposalsSingleImageLayerTest,
    ::testing::Combine(
        ::testing::ValuesIn(input_shape),
        ::testing::ValuesIn(min_size),
        ::testing::ValuesIn(nms_threshold),
        ::testing::ValuesIn(post_nms_count),
        ::testing::ValuesIn(pre_nms_count),
        ::testing::ValuesIn(getInputTensors<ov::float16>()),
        ::testing::Values(ov::element::Type_t::f16),
        ::testing::Values(CommonTestUtils::DEVICE_GPU)),
    ExperimentalDetectronGenerateProposalsSingleImageLayerTest::getTestCaseName);

// GPU, f32: the same sweep with input sets produced by getInputTensors<float>().
INSTANTIATE_TEST_SUITE_P(
    smoke_ExperimentalDetectronGenerateProposalsSingleImageLayerTest_f32,
    ExperimentalDetectronGenerateProposalsSingleImageLayerTest,
    ::testing::Combine(
        ::testing::ValuesIn(input_shape),
        ::testing::ValuesIn(min_size),
        ::testing::ValuesIn(nms_threshold),
        ::testing::ValuesIn(post_nms_count),
        ::testing::ValuesIn(pre_nms_count),
        ::testing::ValuesIn(getInputTensors<float>()),
        ::testing::Values(ov::element::Type_t::f32),
        ::testing::Values(CommonTestUtils::DEVICE_GPU)),
    ExperimentalDetectronGenerateProposalsSingleImageLayerTest::getTestCaseName);

} // namespace
